{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "from torchvision.utils import make_grid\n",
    "from torch_tools.visualization import to_image\n",
    "from visualization import interpolate\n",
    "from loading import load_from_dir\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "deformator, G, shift_predictor = load_from_dir(\n",
    "    './models/pretrained/deformators/SN_MNIST/',\n",
    "    G_weights='./models/pretrained/generators/SN_MNIST/')\n",
    "\n",
    "# deformator, G, shift_predictor = load_from_dir(\n",
    "#     './models/pretrained/deformators/SN_Anime/',\n",
    "#     G_weights='./models/pretrained/generators/SN_Anime/')\n",
    "\n",
    "# deformator, G, shift_predictor = load_from_dir(\n",
    "#     './models/pretrained/deformators/BigGAN/',\n",
    "#     G_weights='./models/pretrained/generators/BigGAN/G_ema.pth')\n",
    "\n",
    "# deformator, G, shift_predictor = load_from_dir(\n",
    "#     './models/pretrained/deformators/ProgGAN/',\n",
    "#     G_weights='./models/pretrained/generators/ProgGAN/100_celeb_hq_network-snapshot-010403.pth')\n",
    "\n",
    "# deformator, G, shift_predictor = load_from_dir(\n",
    "#     './models/pretrained/deformators/StyleGAN2/',\n",
    "#     G_weights='./models/pretrained/generators/StyleGAN2/stylegan2-ffhq-config-f.pt')\n",
    "\n",
    "discovered_annotation = ''\n",
    "for d in deformator.annotation.items():\n",
    "    discovered_annotation += '{}: {}\\n'.format(d[0], d[1])\n",
    "print('human-annotated directions:\\n' + discovered_annotation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import is_conditional\n",
    "\n",
    "rows = 8\n",
    "plt.figure(figsize=(5, rows), dpi=250)\n",
    "\n",
    "# set desired class for conditional GAN\n",
    "if is_conditional(G):\n",
    "    G.set_classes(12)\n",
    "\n",
    "annotated = list(deformator.annotation.values())\n",
    "inspection_dim = annotated[0]\n",
    "zs = torch.randn([rows, G.dim_z] if type(G.dim_z) == int else [rows] + G.dim_z, device='cuda')\n",
    "\n",
    "\n",
    "for z, i in zip(zs, range(rows)):\n",
    "    interpolation_deformed = interpolate(\n",
    "        G, z.unsqueeze(0),\n",
    "        shifts_r=16,\n",
    "        shifts_count=3,\n",
    "        dim=inspection_dim,\n",
    "        deformator=deformator,\n",
    "        with_central_border=True)\n",
    "\n",
    "    plt.subplot(rows, 1, i + 1)\n",
    "    plt.axis('off')\n",
    "    grid = make_grid(interpolation_deformed, nrow=11, padding=1, pad_value=0.0)\n",
    "    grid = torch.clamp(grid, -1, 1)\n",
    "\n",
    "    plt.imshow(to_image(grid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
